{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 683.8283081054688\n",
      "1 683.3124389648438\n",
      "2 682.7969970703125\n",
      "3 682.2822875976562\n",
      "4 681.7678833007812\n",
      "5 681.25439453125\n",
      "6 680.7418823242188\n",
      "7 680.2299194335938\n",
      "8 679.7191772460938\n",
      "9 679.2091674804688\n",
      "10 678.6995239257812\n",
      "11 678.1908569335938\n",
      "12 677.6829833984375\n",
      "13 677.1756591796875\n",
      "14 676.669189453125\n",
      "15 676.16357421875\n",
      "16 675.6588745117188\n",
      "17 675.1555786132812\n",
      "18 674.6533203125\n",
      "19 674.151123046875\n",
      "20 673.650146484375\n",
      "21 673.1502075195312\n",
      "22 672.6503295898438\n",
      "23 672.151611328125\n",
      "24 671.6534423828125\n",
      "25 671.156494140625\n",
      "26 670.6597290039062\n",
      "27 670.1637573242188\n",
      "28 669.66845703125\n",
      "29 669.1732788085938\n",
      "30 668.6795043945312\n",
      "31 668.1871948242188\n",
      "32 667.6953735351562\n",
      "33 667.2039184570312\n",
      "34 666.7141723632812\n",
      "35 666.2247924804688\n",
      "36 665.7361450195312\n",
      "37 665.2484741210938\n",
      "38 664.76123046875\n",
      "39 664.2737426757812\n",
      "40 663.7880249023438\n",
      "41 663.3018798828125\n",
      "42 662.8174438476562\n",
      "43 662.3321533203125\n",
      "44 661.849365234375\n",
      "45 661.3671264648438\n",
      "46 660.8865356445312\n",
      "47 660.4055786132812\n",
      "48 659.9256591796875\n",
      "49 659.4485473632812\n",
      "50 658.9702758789062\n",
      "51 658.4944458007812\n",
      "52 658.0178833007812\n",
      "53 657.5425415039062\n",
      "54 657.0681762695312\n",
      "55 656.593505859375\n",
      "56 656.1196899414062\n",
      "57 655.6469116210938\n",
      "58 655.17529296875\n",
      "59 654.703857421875\n",
      "60 654.2332153320312\n",
      "61 653.7639770507812\n",
      "62 653.295166015625\n",
      "63 652.82666015625\n",
      "64 652.3593139648438\n",
      "65 651.892822265625\n",
      "66 651.4276123046875\n",
      "67 650.9630126953125\n",
      "68 650.4983520507812\n",
      "69 650.035400390625\n",
      "70 649.5724487304688\n",
      "71 649.1101684570312\n",
      "72 648.648681640625\n",
      "73 648.1878051757812\n",
      "74 647.7286987304688\n",
      "75 647.268798828125\n",
      "76 646.8103637695312\n",
      "77 646.3516845703125\n",
      "78 645.8939208984375\n",
      "79 645.4366455078125\n",
      "80 644.9796752929688\n",
      "81 644.5238647460938\n",
      "82 644.0682983398438\n",
      "83 643.6133422851562\n",
      "84 643.1592407226562\n",
      "85 642.7056274414062\n",
      "86 642.2522583007812\n",
      "87 641.7996215820312\n",
      "88 641.3473510742188\n",
      "89 640.8960571289062\n",
      "90 640.4456787109375\n",
      "91 639.9957275390625\n",
      "92 639.5453491210938\n",
      "93 639.09716796875\n",
      "94 638.6486206054688\n",
      "95 638.2008056640625\n",
      "96 637.7540893554688\n",
      "97 637.3070678710938\n",
      "98 636.86083984375\n",
      "99 636.4157104492188\n",
      "100 635.970947265625\n",
      "101 635.5263671875\n",
      "102 635.0825805664062\n",
      "103 634.6395874023438\n",
      "104 634.19677734375\n",
      "105 633.7551879882812\n",
      "106 633.3141479492188\n",
      "107 632.874267578125\n",
      "108 632.4343872070312\n",
      "109 631.9948120117188\n",
      "110 631.5556030273438\n",
      "111 631.1172485351562\n",
      "112 630.6792602539062\n",
      "113 630.241943359375\n",
      "114 629.8052978515625\n",
      "115 629.3685913085938\n",
      "116 628.9327392578125\n",
      "117 628.4976806640625\n",
      "118 628.0630493164062\n",
      "119 627.6289672851562\n",
      "120 627.19580078125\n",
      "121 626.7635498046875\n",
      "122 626.3309936523438\n",
      "123 625.8993530273438\n",
      "124 625.4681396484375\n",
      "125 625.037841796875\n",
      "126 624.607177734375\n",
      "127 624.1779174804688\n",
      "128 623.7487182617188\n",
      "129 623.3195190429688\n",
      "130 622.891357421875\n",
      "131 622.4630126953125\n",
      "132 622.0357055664062\n",
      "133 621.6083374023438\n",
      "134 621.181396484375\n",
      "135 620.7551879882812\n",
      "136 620.3294067382812\n",
      "137 619.9039306640625\n",
      "138 619.4794921875\n",
      "139 619.0555419921875\n",
      "140 618.6326293945312\n",
      "141 618.2092895507812\n",
      "142 617.7877807617188\n",
      "143 617.365966796875\n",
      "144 616.945068359375\n",
      "145 616.5245361328125\n",
      "146 616.104736328125\n",
      "147 615.6851196289062\n",
      "148 615.2661743164062\n",
      "149 614.8483276367188\n",
      "150 614.4300537109375\n",
      "151 614.0127563476562\n",
      "152 613.5955200195312\n",
      "153 613.1791381835938\n",
      "154 612.762939453125\n",
      "155 612.34814453125\n",
      "156 611.93359375\n",
      "157 611.5191650390625\n",
      "158 611.10498046875\n",
      "159 610.6915893554688\n",
      "160 610.2784423828125\n",
      "161 609.8657836914062\n",
      "162 609.453857421875\n",
      "163 609.0420532226562\n",
      "164 608.6306762695312\n",
      "165 608.220458984375\n",
      "166 607.8095703125\n",
      "167 607.3988037109375\n",
      "168 606.9900512695312\n",
      "169 606.58056640625\n",
      "170 606.1717529296875\n",
      "171 605.7634887695312\n",
      "172 605.3558349609375\n",
      "173 604.9484252929688\n",
      "174 604.5410766601562\n",
      "175 604.1349487304688\n",
      "176 603.7288208007812\n",
      "177 603.3236083984375\n",
      "178 602.9190063476562\n",
      "179 602.5148315429688\n",
      "180 602.1109619140625\n",
      "181 601.708251953125\n",
      "182 601.305908203125\n",
      "183 600.9039916992188\n",
      "184 600.5026245117188\n",
      "185 600.10205078125\n",
      "186 599.7015991210938\n",
      "187 599.301513671875\n",
      "188 598.9022827148438\n",
      "189 598.5035400390625\n",
      "190 598.1049194335938\n",
      "191 597.7069702148438\n",
      "192 597.3096923828125\n",
      "193 596.9131469726562\n",
      "194 596.516845703125\n",
      "195 596.1211547851562\n",
      "196 595.7251586914062\n",
      "197 595.3302001953125\n",
      "198 594.93505859375\n",
      "199 594.5411987304688\n",
      "200 594.1467895507812\n",
      "201 593.7532958984375\n",
      "202 593.3599243164062\n",
      "203 592.9668579101562\n",
      "204 592.5755004882812\n",
      "205 592.1836547851562\n",
      "206 591.7922973632812\n",
      "207 591.4014282226562\n",
      "208 591.0108642578125\n",
      "209 590.6207275390625\n",
      "210 590.2319946289062\n",
      "211 589.8434448242188\n",
      "212 589.4552001953125\n",
      "213 589.0681762695312\n",
      "214 588.6807861328125\n",
      "215 588.2935180664062\n",
      "216 587.9067993164062\n",
      "217 587.5213012695312\n",
      "218 587.1353759765625\n",
      "219 586.7499389648438\n",
      "220 586.3648071289062\n",
      "221 585.98095703125\n",
      "222 585.59716796875\n",
      "223 585.2139892578125\n",
      "224 584.8309936523438\n",
      "225 584.4489135742188\n",
      "226 584.0661010742188\n",
      "227 583.6846313476562\n",
      "228 583.302490234375\n",
      "229 582.9213256835938\n",
      "230 582.54052734375\n",
      "231 582.16064453125\n",
      "232 581.7803955078125\n",
      "233 581.4000244140625\n",
      "234 581.0210571289062\n",
      "235 580.6419677734375\n",
      "236 580.263671875\n",
      "237 579.8859252929688\n",
      "238 579.5077514648438\n",
      "239 579.1301879882812\n",
      "240 578.7532958984375\n",
      "241 578.3767700195312\n",
      "242 578.0003051757812\n",
      "243 577.6253662109375\n",
      "244 577.2506713867188\n",
      "245 576.8768310546875\n",
      "246 576.5027465820312\n",
      "247 576.1293334960938\n",
      "248 575.7565307617188\n",
      "249 575.3844604492188\n",
      "250 575.0118408203125\n",
      "251 574.6407470703125\n",
      "252 574.2691040039062\n",
      "253 573.8984375\n",
      "254 573.5289306640625\n",
      "255 573.1588134765625\n",
      "256 572.7891845703125\n",
      "257 572.420166015625\n",
      "258 572.0510864257812\n",
      "259 571.6832885742188\n",
      "260 571.3156127929688\n",
      "261 570.9489135742188\n",
      "262 570.5814208984375\n",
      "263 570.215576171875\n",
      "264 569.8491821289062\n",
      "265 569.4830322265625\n",
      "266 569.1181640625\n",
      "267 568.7535400390625\n",
      "268 568.3889770507812\n",
      "269 568.0253295898438\n",
      "270 567.6614379882812\n",
      "271 567.2977294921875\n",
      "272 566.9348754882812\n",
      "273 566.572021484375\n",
      "274 566.2095336914062\n",
      "275 565.8471069335938\n",
      "276 565.4848022460938\n",
      "277 565.12353515625\n",
      "278 564.7621459960938\n",
      "279 564.4014892578125\n",
      "280 564.0414428710938\n",
      "281 563.6817626953125\n",
      "282 563.322998046875\n",
      "283 562.9645385742188\n",
      "284 562.6057739257812\n",
      "285 562.2478637695312\n",
      "286 561.8897705078125\n",
      "287 561.5322265625\n",
      "288 561.1744995117188\n",
      "289 560.8177490234375\n",
      "290 560.461181640625\n",
      "291 560.1044921875\n",
      "292 559.748291015625\n",
      "293 559.39306640625\n",
      "294 559.0386352539062\n",
      "295 558.6842651367188\n",
      "296 558.3301391601562\n",
      "297 557.9769897460938\n",
      "298 557.6254272460938\n",
      "299 557.2730712890625\n",
      "300 556.9218139648438\n",
      "301 556.5706176757812\n",
      "302 556.2200317382812\n",
      "303 555.869873046875\n",
      "304 555.5202026367188\n",
      "305 555.170654296875\n",
      "306 554.8215942382812\n",
      "307 554.472900390625\n",
      "308 554.1243286132812\n",
      "309 553.7764892578125\n",
      "310 553.4288940429688\n",
      "311 553.08154296875\n",
      "312 552.735107421875\n",
      "313 552.3887329101562\n",
      "314 552.0430297851562\n",
      "315 551.6969604492188\n",
      "316 551.3511962890625\n",
      "317 551.0062866210938\n",
      "318 550.6616821289062\n",
      "319 550.31689453125\n",
      "320 549.9722290039062\n",
      "321 549.6280517578125\n",
      "322 549.284912109375\n",
      "323 548.9415893554688\n",
      "324 548.5982666015625\n",
      "325 548.2562255859375\n",
      "326 547.9142456054688\n",
      "327 547.572265625\n",
      "328 547.2308349609375\n",
      "329 546.8901977539062\n",
      "330 546.5496826171875\n",
      "331 546.2094116210938\n",
      "332 545.8696899414062\n",
      "333 545.5298461914062\n",
      "334 545.1903076171875\n",
      "335 544.8517456054688\n",
      "336 544.5137939453125\n",
      "337 544.1766357421875\n",
      "338 543.8397216796875\n",
      "339 543.5028686523438\n",
      "340 543.1663208007812\n",
      "341 542.830078125\n",
      "342 542.494384765625\n",
      "343 542.1589965820312\n",
      "344 541.82421875\n",
      "345 541.489501953125\n",
      "346 541.1555786132812\n",
      "347 540.8209838867188\n",
      "348 540.4873657226562\n",
      "349 540.1537475585938\n",
      "350 539.8203125\n",
      "351 539.4876708984375\n",
      "352 539.1554565429688\n",
      "353 538.8225708007812\n",
      "354 538.4906005859375\n",
      "355 538.1587524414062\n",
      "356 537.8274536132812\n",
      "357 537.4962768554688\n",
      "358 537.1646118164062\n",
      "359 536.8341674804688\n",
      "360 536.5037231445312\n",
      "361 536.1734619140625\n",
      "362 535.8429565429688\n",
      "363 535.513427734375\n",
      "364 535.183837890625\n",
      "365 534.8546142578125\n",
      "366 534.5261840820312\n",
      "367 534.1976928710938\n",
      "368 533.869873046875\n",
      "369 533.5415649414062\n",
      "370 533.2139892578125\n",
      "371 532.8865966796875\n",
      "372 532.5602416992188\n",
      "373 532.2329711914062\n",
      "374 531.9064331054688\n",
      "375 531.5800170898438\n",
      "376 531.253662109375\n",
      "377 530.9281005859375\n",
      "378 530.6019897460938\n",
      "379 530.2769775390625\n",
      "380 529.9509887695312\n",
      "381 529.6262817382812\n",
      "382 529.3015747070312\n",
      "383 528.9766845703125\n",
      "384 528.6526489257812\n",
      "385 528.3280029296875\n",
      "386 528.004150390625\n",
      "387 527.6810302734375\n",
      "388 527.3579711914062\n",
      "389 527.0352783203125\n",
      "390 526.712890625\n",
      "391 526.3905639648438\n",
      "392 526.0687255859375\n",
      "393 525.7471923828125\n",
      "394 525.4256591796875\n",
      "395 525.1053466796875\n",
      "396 524.7848510742188\n",
      "397 524.4649047851562\n",
      "398 524.1452026367188\n",
      "399 523.8272094726562\n",
      "400 523.5095825195312\n",
      "401 523.1922607421875\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "402 522.8744506835938\n",
      "403 522.5575561523438\n",
      "404 522.24072265625\n",
      "405 521.9238891601562\n",
      "406 521.6070556640625\n",
      "407 521.2906494140625\n",
      "408 520.9743041992188\n",
      "409 520.6585693359375\n",
      "410 520.3432006835938\n",
      "411 520.028076171875\n",
      "412 519.7131958007812\n",
      "413 519.3980712890625\n",
      "414 519.0833740234375\n",
      "415 518.769287109375\n",
      "416 518.4549560546875\n",
      "417 518.1414184570312\n",
      "418 517.8278198242188\n",
      "419 517.5144653320312\n",
      "420 517.2015991210938\n",
      "421 516.8885498046875\n",
      "422 516.575927734375\n",
      "423 516.2636108398438\n",
      "424 515.951416015625\n",
      "425 515.6396484375\n",
      "426 515.327880859375\n",
      "427 515.0166015625\n",
      "428 514.7051391601562\n",
      "429 514.3941650390625\n",
      "430 514.0833740234375\n",
      "431 513.7724609375\n",
      "432 513.4624633789062\n",
      "433 513.152099609375\n",
      "434 512.8423461914062\n",
      "435 512.5328979492188\n",
      "436 512.2235107421875\n",
      "437 511.9145812988281\n",
      "438 511.60589599609375\n",
      "439 511.2974548339844\n",
      "440 510.98907470703125\n",
      "441 510.6817626953125\n",
      "442 510.374267578125\n",
      "443 510.067138671875\n",
      "444 509.7607421875\n",
      "445 509.4538269042969\n",
      "446 509.1470947265625\n",
      "447 508.8416442871094\n",
      "448 508.5361022949219\n",
      "449 508.2306213378906\n",
      "450 507.9258728027344\n",
      "451 507.62066650390625\n",
      "452 507.31591796875\n",
      "453 507.011474609375\n",
      "454 506.7081298828125\n",
      "455 506.4047546386719\n",
      "456 506.1019592285156\n",
      "457 505.79840087890625\n",
      "458 505.4956970214844\n",
      "459 505.19293212890625\n",
      "460 504.8902893066406\n",
      "461 504.5878601074219\n",
      "462 504.2854309082031\n",
      "463 503.98388671875\n",
      "464 503.6827087402344\n",
      "465 503.38153076171875\n",
      "466 503.0809020996094\n",
      "467 502.7801513671875\n",
      "468 502.4800109863281\n",
      "469 502.1800842285156\n",
      "470 501.88055419921875\n",
      "471 501.5813293457031\n",
      "472 501.2817077636719\n",
      "473 500.9828796386719\n",
      "474 500.68377685546875\n",
      "475 500.3855895996094\n",
      "476 500.0868225097656\n",
      "477 499.78826904296875\n",
      "478 499.4908447265625\n",
      "479 499.1929626464844\n",
      "480 498.89544677734375\n",
      "481 498.5979309082031\n",
      "482 498.3008117675781\n",
      "483 498.0042724609375\n",
      "484 497.7075500488281\n",
      "485 497.4110107421875\n",
      "486 497.1147766113281\n",
      "487 496.8188781738281\n",
      "488 496.5232849121094\n",
      "489 496.2281799316406\n",
      "490 495.9330749511719\n",
      "491 495.6383056640625\n",
      "492 495.343994140625\n",
      "493 495.050048828125\n",
      "494 494.75592041015625\n",
      "495 494.4626159667969\n",
      "496 494.16937255859375\n",
      "497 493.8759765625\n",
      "498 493.5833435058594\n",
      "499 493.2904357910156\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cpu')\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10  # samples, input dim, hidden dim, output dim\n",
    "learning_rate = 1e-6\n",
    "num_steps = 500\n",
    "# NOTE(review): with this step size and a summed loss, the saved run only lowered\n",
    "# the loss from ~684 to ~493 over 500 steps; consider a larger learning rate.\n",
    "\n",
    "# Seed the global RNG so the random data, the weight init and the printed losses\n",
    "# are reproducible on Restart & Run All.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Random inputs and regression targets: a toy problem with no real dataset.\n",
    "x = torch.randn(N, D_in, device=device)\n",
    "y = torch.randn(N, D_out, device=device)\n",
    "\n",
    "# Two-layer MLP: D_in -> H -> ReLU -> D_out.\n",
    "model = torch.nn.Sequential(\n",
    "            torch.nn.Linear(D_in, H),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(H, D_out),\n",
    "        ).to(device)\n",
    "\n",
    "# Sum (not mean) of squared errors over all elements. `size_average=False` is\n",
    "# deprecated; `reduction='sum'` is the equivalent current spelling.\n",
    "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
    "\n",
    "for t in range(num_steps):\n",
    "    # Forward pass and loss.\n",
    "    y_pred = model(x)\n",
    "    loss = loss_fn(y_pred, y)\n",
    "    print(t, loss.item())\n",
    "\n",
    "    # Gradients accumulate across backward() calls, so clear them before backprop.\n",
    "    model.zero_grad()\n",
    "    loss.backward()\n",
    "\n",
    "    # Manual SGD step. Updating the parameters in place under no_grad() keeps the\n",
    "    # update out of the autograd graph; going through `.data` is unnecessary.\n",
    "    with torch.no_grad():\n",
    "        for param in model.parameters():\n",
    "            param -= learning_rate * param.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
